# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import logging
from dataclasses import asdict
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from xformers._deprecation_warning import deprecated_function
from xformers.components import (
    PatchEmbeddingConfig,
    PostNorm,
    PreNorm,
    Residual,
    ResidualNormStyle,
    build_multi_head_attention,
    build_patch_embedding,
)
from xformers.components.attention import AttentionMask
from xformers.components.feedforward import build_feedforward
from xformers.components.positional_embedding import build_positional_embedding
from xformers.components.residual import get_deepnorm_coefficients
from xformers.components.simplicial_embedding import SimplicialEmbedding
from xformers.factory.block_configs import (
    NormalizationType,
    xFormerDecoderConfig,
    xFormerEncoderConfig,
)

logger = logging.getLogger("xformers")


def _get_ln_factory(
    d_model: int,
    residual_norm_style: Optional[ResidualNormStyle],
    use_triton: bool,
    residual: bool,
    normalization: NormalizationType = NormalizationType.LayerNorm,
    residual_scale: float = 1.0,
):
    """
    Handle all the supported residual path configurations.

    ..Note: we return the appropriate constructor, not an actual layer
    """

    def get_layer_wrapper(
        d_model: int,
        sublayer: nn.Module,
        residual_norm_style: Optional[ResidualNormStyle],
        residual: bool,
        residual_scale: float,
    ):
        if residual:
            if residual_norm_style == ResidualNormStyle.Pre:
                return Residual(
                    layer=PreNorm(d_model, sublayer, normalization, use_triton),
                    scale=None,
                )
            elif residual_norm_style == ResidualNormStyle.Post:
                return PostNorm(
                    d_model,
                    Residual(layer=sublayer, scale=None),
                    normalization,
                    use_triton,
                )
            elif residual_norm_style == ResidualNormStyle.DeepNorm:
                return PostNorm(
                    d_model,
                    Residual(layer=sublayer, scale=residual_scale),
                    normalization,
                    use_triton=use_triton,
                )
            else:
                raise ValueError

        return (
            PreNorm(d_model, sublayer, normalization, use_triton)
            if residual_norm_style == ResidualNormStyle.Pre
            else PostNorm(d_model, sublayer, normalization, use_triton)
        )

    def ln_factory(sublayer: nn.Module):
        return get_layer_wrapper(
            d_model, sublayer, residual_norm_style, residual, residual_scale
        )

    return ln_factory


class xFormerEncoderBlock(torch.nn.Module):
    r"""A vanilla Transformer Encoder block"""

    def __init__(self, config: xFormerEncoderConfig, **kwargs):
        super().__init__()
        deprecated_function(self)

        self.reversible_f = None
        self.reversible_g = None
        self.residual_norm_style = config.residual_norm_style
        self.dim_model = config.dim_model

        # If this layer is the first one, and a pose encoding has been requested
        if (
            config.position_encoding_config is not None
            and config.layer_position.is_first()
        ):
            self.pose_encoding = build_positional_embedding(
                asdict(config.position_encoding_config)
            )

            pos_encoding_dim = config.position_encoding_config.dim_model
            mha_dim = config.multi_head_config["dim_model"]

            if pos_encoding_dim != mha_dim:
                logger.warning(
                    f"The embedding dim and model dim do not match ({pos_encoding_dim} vs {mha_dim}), adding a projector layer."  # noqa
                )
                self.embedding_projector = nn.Linear(pos_encoding_dim, mha_dim)
        else:
            self.pose_encoding = None

        if config.residual_norm_style == ResidualNormStyle.DeepNorm:
            # Just use the layer norm coefficient here,
            # the init will be handled at the xformers level (knows about encoder and decoder blocks)
            deep_norm_coefficients, _ = get_deepnorm_coefficients(
                encoder_layers=config.num_layers, decoder_layers=0
            )
            assert deep_norm_coefficients is not None
            residual_scale = deep_norm_coefficients.alpha
        else:
            residual_scale = 1.0

        # mini helper, builds a normalization layer with the right Pre/Post config, residuals, and the right dimensions
        ln_factory = _get_ln_factory(
            config.dim_model,
            config.residual_norm_style,
            use_triton=config.use_triton,
            residual=True,
            residual_scale=residual_scale,
            normalization=config.normalization,
        )

        mha = build_multi_head_attention(config.multi_head_config)
        feedforward = build_feedforward(asdict(config.feedforward_config))

        # Expose attention specific capabilities
        self.supports_attention_mask = mha.attention.supports_attention_mask
        self.requires_same_k_q_dimensions = mha.attention.requires_same_k_q_dimensions
        self.causal = (
            mha.attention.causal if hasattr(mha.attention, "causal") else False
        )

        # Wrappers handle the different layer norm styles (pre- and post-) and the residual path
        self.wrap_att = ln_factory(mha)
        self.wrap_ff: Union[Residual, PostNorm] = ln_factory(feedforward)
        if (
            config.residual_norm_style == ResidualNormStyle.Pre
            and config.layer_position.is_last()
        ):
            self.wrap_ff = PostNorm(
                config.dim_model,
                self.wrap_ff,
                normalization=config.normalization,
                use_triton=config.use_triton,
            )

        # Simplicial embeddings are only used if specified, and on the last layer
        self.simplicial_embedding: Optional[SimplicialEmbedding] = None
        if config.simplicial_embeddings is not None and config.layer_position.is_last():
            self.simplicial_embedding = SimplicialEmbedding(
                **config.simplicial_embeddings
            )

        # Optional patch embedding
        self.patch_emb: Optional[nn.Module] = None

        if config.patch_embedding_config is not None:
            self.patch_emb = build_patch_embedding(
                PatchEmbeddingConfig(**config.patch_embedding_config)
            )

    @classmethod
    def from_config(cls, config: xFormerEncoderConfig):
        return cls(config)

    @staticmethod
    def get_reversible_layer(config) -> Tuple[nn.Module, nn.Module]:
        ln_factory = _get_ln_factory(
            config.dim_model,
            config.residual_norm_style,
            residual=False,
            use_triton=config.use_triton,
            normalization=config.normalization,
        )

        mha = build_multi_head_attention(config.multi_head_config)
        feedforward = build_feedforward(asdict(config.feedforward_config))

        reversible_f = ln_factory(mha)
        reversible_g = ln_factory(feedforward)
        return reversible_f, reversible_g

    def forward(
        self,
        x: torch.Tensor,
        att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
        input_mask: Optional[torch.Tensor] = None,
    ):
        if self.patch_emb is not None:
            x = self.patch_emb(x)

        if self.pose_encoding is not None:
            x = self.pose_encoding(x)

            if hasattr(self, "embedding_projector"):
                x = self.embedding_projector(x)

        # Handle the optional input masking, differs on Q, K, V
        if input_mask is not None:
            q = x
            k = x * input_mask.unsqueeze(-1)
            v = k
        else:
            q, k, v = x, x, x

        # Pre/Post norms and residual paths are already handled
        x = self.wrap_att(inputs=[q, k, v], att_mask=att_mask)
        x = self.wrap_ff(inputs=[x])

        # Optional simplicial embeddings
        if self.simplicial_embedding is not None:
            x = self.simplicial_embedding(x)

        return x


class xFormerDecoderBlock(torch.nn.Module):
    r"""A vanilla Transformer Decoder block

    ... note: this implementation is not (yet ?) reversible"""

    def __init__(self, config: xFormerDecoderConfig, **kwargs):
        super().__init__()
        deprecated_function(self)

        # If this layer is the first one, and a pose encoding as been requested
        if (
            config.position_encoding_config is not None
            and config.layer_position.is_first()
        ):
            self.pose_encoding = build_positional_embedding(
                config.position_encoding_config
            )

            pos_encoding_dim = config.position_encoding_config.dim_model
            mha_dim = config.multi_head_config_masked["dim_model"]

            if pos_encoding_dim != mha_dim:

                logger.warning(
                    f"The embedding dim and model dim do not match ({pos_encoding_dim} vs {mha_dim}), adding a projector layer."  # noqa
                )

                self.embedding_projector = nn.Linear(pos_encoding_dim, mha_dim)
        else:
            self.pose_encoding = None

        if config.residual_norm_style == ResidualNormStyle.DeepNorm:
            # Just use the layer norm coefficient here,
            # the init will be handled at the xformers level (knows about encoder and decoder blocks)
            _, deep_norm_coefficients = get_deepnorm_coefficients(
                encoder_layers=0, decoder_layers=config.num_layers
            )
            assert deep_norm_coefficients is not None
            residual_scale = deep_norm_coefficients.alpha
        else:
            residual_scale = 1.0

        # mini helper, builds a LayerNorm with the right Pre/Post config and the right dimensions
        ln_factory = _get_ln_factory(
            config.dim_model,
            config.residual_norm_style,
            use_triton=config.use_triton,
            residual=True,
            residual_scale=residual_scale,
            normalization=config.normalization,
        )

        mha = build_multi_head_attention(config.multi_head_config_masked)
        cross_mha = build_multi_head_attention(config.multi_head_config_cross)
        feedforward = build_feedforward(config.feedforward_config)

        # Expose attention or feedforward specific capabilities
        self.supports_attention_mask = mha.attention.supports_attention_mask
        self.requires_same_k_q_dimensions = mha.attention.requires_same_k_q_dimensions
        self.requires_squared_context_length = (
            feedforward.requires_squared_context
            or mha.attention.requires_squared_context
        )

        self.causal_attention = (
            mha.attention.causal if hasattr(mha.attention, "causal") else False
        )

        # Wrappers handle the different layer norm styles (pre- and post-) and the residual path
        self.wrap_att = ln_factory(mha)
        self.wrap_cross = ln_factory(cross_mha)
        self.wrap_ff: Union[Residual, PostNorm] = ln_factory(feedforward)

        if (
            config.residual_norm_style == ResidualNormStyle.Pre
            and config.layer_position.is_last()
        ):
            self.wrap_ff = PostNorm(
                config.dim_model,
                self.wrap_ff,
                normalization=NormalizationType.LayerNorm,
            )

    @classmethod
    def from_config(cls, config: xFormerDecoderConfig):
        return cls(config)

    def forward(
        self,
        target: torch.Tensor,
        memory: torch.Tensor,
        encoder_att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
        decoder_att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None,
        input_mask: Optional[torch.Tensor] = None,
    ):
        if self.pose_encoding is not None:
            target = self.pose_encoding(target)

            if hasattr(self, "embedding_projector"):
                target = self.embedding_projector(target)

        # Handle the optional input masking, differs on Q, K, V
        if input_mask is not None:
            target_q = target
            target_k = target * input_mask.unsqueeze(-1)
            target_v = target_k
        else:
            target_q, target_k, target_v = target, target, target

        x = self.wrap_att(
            inputs=[target_q, target_k, target_v], att_mask=decoder_att_mask
        )
        x = self.wrap_cross(inputs=[x, memory, memory], att_mask=encoder_att_mask)
        x = self.wrap_ff(inputs=[x])

        return x
